查看原文
其他

GPU上的高效softmax近似

NeuralTalk 2022-11-28

The following article is from 雨石记 Author 张雨石

在语言模型的训练中,词表的大小往往是一个瓶颈,因为要计算context在所有词语上的概率分布。这个分布则是由softmax计算得到的。

论文[1]针对softmax做了一些优化,尤其对GPU环境做了适配,将softmax的计算提升2倍到10倍。

Overall

一般针对softmax的优化可以分为三类:

  • 损失函数的近似,比如,使用层次化+霍夫曼编码的方法进行优化。
  • 基于采样的方法,基于unigram或者bigram的频次对softmax进行采样,从而加速计算。
  • 自归一化方法,计算出来的logits做完指数后直接就是概率,无需做softmax里的归一化。

本文所讨论的方法属于第一类。

语言模型回顾

语言模型,定义就是基于前面的context,去预测下一个词语。

基于统计的语言模型会控制context中的词语数,比如bigram就是context里只有一个词,trigram则是context里包含两个词语。

而现在基于神经网络的语言模型则没有这个限制,比如LSTM、Transformer等理论上可以看到context里所有的词语。

之所以说是理论上,是因为往往为了速度考虑,不会看很长,比如Bert的长度限制就是512.

而概率分布则是通过softmax来计算:

对softmax进行加速的比较流行的方法就是层次化softmax,即对词表进行分组,先计算组的概率,然后再计算组内每个词的概率。而最后词语的概率则是组概率乘以组内词语概率。

本文所介绍的方法就是基于层次化softmax的。

GPU上的矩阵计算

Softmax的计算分为两部分: 矩阵计算和指数计算。其中,矩阵计算占的比重比较大。

在GPU上,矩阵乘法消耗的时间有这样一个规律,两个矩阵,W1和W2,在固定W1的维度和W2的第一维后,计算时间和W2第二维的关系如下。

可以看到,k=50之前,计算时间几乎是一个常量,而之后则是一个线性增长的趋势。

在考虑到batch_size,在GPU上计算softmax的时间可以用下面的公式来表示,其中k0B0代表的就是那个趋势变化点。

词表聚类

词语的分布服从二八定律,具体来说就是出现频率最高的20%的词语占到了全部词语的87%。因此,这里比较自然的将词语分为两部分,head部分和tail部分。其中 |Vh| << |Vt|,而P(Vh) >> P(Vt)。

分为这两类后,在层次softmax中,就有两种方法去构建层次:

  • head和tail组分别有一个父节点,然后两个父节点在root下面。
  • tail有一个父节点,父节点在root下面,head组的所有词语都在root下面。

根据实验,第一种方法会导致5%~10%的performance降低。因此,论文采用了第二种策略。

二分组下的时间最小化

结构定下来后,还需要确定有多少词语在head组里,多少在tail组里。假设词语总数是k,head组有khead个词语,tail组有ktail个词语。显然,khead+ktail = k。

那么需要的时间就是:

时间与khead的关系是:

可以看到,一个好的个数的选择,相对于full Softmax可以带来5x的速度提升。

稀疏词语的维度

在计算softmax的时候,相当于context的embedding和词语的embedding去做相似度计算。但因为出现次数少的词语信息不多,所以没有必要有很高的embedding size。

所以,此时可以先对它们进行降维,用一个projection就可以实现,然后再进行矩阵计算。

通常来说,可以让dt=d/4。

二分组到多分组

tail组里的词语可能会非常多,因此,还可以继续分组。如下图:

此时,时间复杂度如下,其中pi是分组中所有词语的概率和。

假设每个分组的kB >= k0B0。那么这个公式可以变成:

其中,J代表分组个数,c是计算时间中较平k较小时的那个常量值。

此时,给定J的情况下,如果要觉得每个分组中有多少词语,需要用动态规划算法。

计算时间和J的关系如下:

实验

从下面两个图中可以看出,本文提出的方法速度快的同时收敛也快。

参考文献

  • [1]. Joulin, A., Cissé, M., Grangier, D., & Jégou, H. (2017, July). Efficient softmax approximation for gpus. In International Conference on Machine Learning (pp. 1302-1310). PMLR.


您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存